Conversation

@Chengqian-Zhang (Collaborator) commented Aug 27, 2025

In the finetuning process, the computation of the fitting stat was skipped in the previous code. There are two situations:

  1. Finetuning from a pretrained model's branch: the pretrained model also has fparam or aparam with the same meaning as in the finetuning task. The keys fparam_avg/fparam_inv_std/aparam_avg/aparam_inv_std are loaded from the pretrained model. This is correct.
  2. Finetuning using RANDOM fitting: the fitting stat should be calculated in this situation, but its computation is currently skipped, which is a bug.
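A sketch of the intended behavior (hypothetical pseudocode, not the actual deepmd-kit control flow; the function and argument names are illustrative only):

```python
# Hypothetical decision logic for fitting-stat handling during finetuning.
def handle_fitting_stat(finetune_from_branch: bool, model, sampled_data) -> None:
    if finetune_from_branch:
        # Case 1: fparam_avg/fparam_inv_std/aparam_avg/aparam_inv_std are
        # loaded from the pretrained model, so nothing needs to be computed.
        pass
    else:
        # Case 2: RANDOM fitting; the stats must be computed from the data.
        model.compute_fitting_input_stat(sampled_data)
```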

Summary by CodeRabbit

  • New Features

    • Automatic computation of input statistics is now performed during bias adjustment in "set-by-statistic" mode; the public API is extended to support computing fitting input statistics.
  • Tests

    • Training tests now compare additional parameter categories to the random-finetuned baseline.
    • Added test-only batching configuration for data-statistics and new unit tests validating fitting input-statistics calculation.

@Chengqian-Zhang marked this pull request as draft on August 27, 2025 08:19
coderabbitai bot (Contributor) commented Aug 27, 2025

📝 Walkthrough

Adds a public hook compute_fitting_input_stat to BaseAtomicModel and DPAtomicModel, wires it into model creation when bias_adjust_mode == "set-by-statistic", implements compute_input_stats in GeneralFitting, and updates tests to set model.data_stat_nbatch and extend state-dict comparisons to include fparam/aparam.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Base atomic model interfaces: deepmd/pd/model/atomic_model/base_atomic_model.py, deepmd/pt/model/atomic_model/base_atomic_model.py | Add a public placeholder method compute_fitting_input_stat(self, sample_merged: Union[Callable[[], list[dict]], list[dict]]) -> None with a docstring; no-op implementation. |
| DP atomic model implementations: deepmd/pd/model/atomic_model/dp_atomic_model.py, deepmd/pt/model/atomic_model/dp_atomic_model.py | Add compute_fitting_input_stat(...) that delegates to self.fitting_net.compute_input_stats(sample_merged, protection=self.data_stat_protect) and replace direct calls to fitting_net.compute_input_stats with this helper in compute_or_load_stat flows. |
| General fitting logic: deepmd/dpmodel/fitting/general_fitting.py | Add compute_input_stats(self, merged: Union[Callable[[], list[dict]], list[dict]], protection: float = 1e-2) -> None, computing per-parameter mean/std for fparam and aparam (supports lazy callable input) and updating stored averages and inverse stds with protection/clipping; see the sketch after this table. |
| Model creation hooks: deepmd/pd/model/model/make_model.py, deepmd/pt/model/model/make_model.py | In change_out_bias, when bias_adjust_mode == "set-by-statistic", call self.atomic_model.compute_fitting_input_stat(merged) after bias adjustment. |
| Tests — training: source/tests/pd/test_training.py, source/tests/pt/test_training.py | Set self.config["model"]["data_stat_nbatch"] = 100 in TestFparam setups; broaden state-dict comparisons to include keys containing fparam or aparam when comparing to the random-finetuned baseline. |
| Tests — fitting stats: source/tests/common/dpmodel/test_fitting_stat.py | Add a unit test validating compute_input_stats: generate synthetic data, compute brute-force fparam/aparam stats, call compute_input_stats, and assert stored averages and inverse stds match the expected results. |
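A minimal sketch of the mean/std-with-protection logic described for GeneralFitting above (a standalone, simplified illustration, not the actual method; it assumes NumPy inputs with the parameter dimension on the last axis and shows only the fparam path, aparam being handled analogously):

```python
from typing import Callable, Union

import numpy as np


def compute_input_stats_sketch(
    merged: Union[Callable[[], list[dict]], list[dict]],
    protection: float = 1e-2,
) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical standalone version of the per-parameter mean/std logic."""
    # Lazy-callable support: materialize the sample list exactly once.
    sampled = merged() if callable(merged) else merged
    # Flatten each system's fparam to (nframes, ndof) and stack all systems.
    fparam = np.concatenate(
        [s["fparam"].reshape(-1, s["fparam"].shape[-1]) for s in sampled], axis=0
    )
    fparam_avg = fparam.mean(axis=0)
    fparam_std = fparam.std(axis=0)
    # Protection: clamp tiny stds so the inverse std is bounded by 1/protection.
    fparam_inv_std = 1.0 / np.maximum(fparam_std, protection)
    return fparam_avg, fparam_inv_std
```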

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant MakeModel as make_model
    participant AtomicModel as atomic_model
    participant FittingNet as fitting_net

    MakeModel->>AtomicModel: change_out_bias(merged, bias_adjust_mode)
    alt bias_adjust_mode == "set-by-statistic"
        MakeModel->>AtomicModel: compute_fitting_input_stat(merged)
        AtomicModel->>FittingNet: compute_input_stats(sample_merged, protection)
        FittingNet-->>AtomicModel: updated stats (fparam_avg, fparam_inv_std, aparam_avg, aparam_inv_std)
    else
        Note right of MakeModel: no fitting input-stat computation
    end
```

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • Review attention:
    • correctness and numeric safety of protection/clipping logic in compute_input_stats.
    • lazy Callable handling and memory behavior when aggregating samples.
    • test assertions in test_fitting_stat.py for numerical tolerances and PD/PT parity.
    • call-site replacements to ensure previous semantics are preserved.

Suggested reviewers

  • wanghan-iapcm
  • njzjz
  • iProzd

Pre-merge checks

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 60.61%, which is below the required threshold of 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped: CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly and specifically describes the main fix: computing fitting statistics during random fitting in the finetuning process, which aligns with the core changeset across all modified files. |

codecov bot commented Nov 7, 2025

Codecov Report

❌ Patch coverage is 92.30769% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 84.23%. Comparing base (25fa707) to head (5ebd7e3).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| deepmd/pd/model/atomic_model/base_atomic_model.py | 50.00% | 1 Missing ⚠️ |
Additional details and impacted files
@@           Coverage Diff           @@
##            devel    #4928   +/-   ##
=======================================
  Coverage   84.23%   84.23%           
=======================================
  Files         709      709           
  Lines       70078    70092   +14     
  Branches     3619     3619           
=======================================
+ Hits        59032    59044   +12     
- Misses       9880     9883    +3     
+ Partials     1166     1165    -1     

☔ View full report in Codecov by Sentry.

@Chengqian-Zhang marked this pull request as ready for review on November 7, 2025 09:32
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7778e2e and 59c50c8.

📒 Files selected for processing (8)
  • deepmd/pd/model/atomic_model/base_atomic_model.py (1 hunks)
  • deepmd/pd/model/atomic_model/dp_atomic_model.py (1 hunks)
  • deepmd/pd/model/model/make_model.py (1 hunks)
  • deepmd/pt/model/atomic_model/base_atomic_model.py (1 hunks)
  • deepmd/pt/model/atomic_model/dp_atomic_model.py (2 hunks)
  • deepmd/pt/model/model/make_model.py (1 hunks)
  • source/tests/pd/test_training.py (2 hunks)
  • source/tests/pt/test_training.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

📄 CodeRabbit inference engine (AGENTS.md)

Always run ruff check . and ruff format . before committing changes to Python code

Files:

  • deepmd/pt/model/atomic_model/dp_atomic_model.py
  • deepmd/pt/model/atomic_model/base_atomic_model.py
  • source/tests/pd/test_training.py
  • deepmd/pd/model/atomic_model/dp_atomic_model.py
  • deepmd/pd/model/atomic_model/base_atomic_model.py
  • deepmd/pd/model/model/make_model.py
  • deepmd/pt/model/model/make_model.py
  • source/tests/pt/test_training.py
🧠 Learnings (2)
📚 Learning: 2025-09-18T11:37:10.532Z
Learnt from: CR
Repo: deepmodeling/deepmd-kit PR: 0
File: AGENTS.md:0-0
Timestamp: 2025-09-18T11:37:10.532Z
Learning: Applies to source/tests/tf/test_dp_test.py : Keep the core TensorFlow test `source/tests/tf/test_dp_test.py` passing; use it for quick validation

Applied to files:

  • source/tests/pd/test_training.py
  • source/tests/pt/test_training.py
📚 Learning: 2024-09-19T04:25:12.408Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-09-19T04:25:12.408Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR #3905.

Applied to files:

  • source/tests/pd/test_training.py
  • source/tests/pt/test_training.py
🧬 Code graph analysis (6)
deepmd/pt/model/atomic_model/dp_atomic_model.py (3)
deepmd/pd/model/atomic_model/dp_atomic_model.py (1)
  • compute_fitting_stat (406-424)
deepmd/pt/model/atomic_model/base_atomic_model.py (1)
  • compute_fitting_stat (496-512)
deepmd/pt/model/task/fitting.py (1)
  • compute_input_stats (78-157)
deepmd/pt/model/atomic_model/base_atomic_model.py (1)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
  • compute_fitting_stat (336-354)
deepmd/pd/model/atomic_model/dp_atomic_model.py (3)
deepmd/pd/model/atomic_model/base_atomic_model.py (1)
  • compute_fitting_stat (518-534)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
  • compute_fitting_stat (336-354)
deepmd/pd/model/task/fitting.py (1)
  • compute_input_stats (75-160)
deepmd/pd/model/atomic_model/base_atomic_model.py (1)
deepmd/pd/model/atomic_model/dp_atomic_model.py (1)
  • compute_fitting_stat (406-424)
deepmd/pd/model/model/make_model.py (3)
deepmd/pd/model/atomic_model/base_atomic_model.py (1)
  • compute_fitting_stat (518-534)
deepmd/pd/model/atomic_model/dp_atomic_model.py (1)
  • compute_fitting_stat (406-424)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
  • compute_fitting_stat (336-354)
deepmd/pt/model/model/make_model.py (2)
deepmd/pt/model/atomic_model/base_atomic_model.py (1)
  • compute_fitting_stat (496-512)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
  • compute_fitting_stat (336-354)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (29)
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Analyze (python)
🔇 Additional comments (10)
deepmd/pd/model/model/make_model.py (1)

231-232: LGTM! Correct invocation of compute_fitting_stat.

The conditional call to compute_fitting_stat after change_out_bias appropriately addresses the PR objective of computing fitting statistics when using random fitting in finetuning (set-by-statistic mode). The implementation correctly reuses the merged data.
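A hedged sketch of the call flow this comment approves (a hypothetical simplification, not the actual deepmd code; note that the hook appears under two names in this thread, compute_fitting_stat and compute_fitting_input_stat, across revisions, and the sketch uses the former):

```python
class ModelSketch:
    """Hypothetical skeleton of the wrapped model; only the hook is shown."""

    def change_out_bias(self, merged, bias_adjust_mode="change-by-statistic") -> None:
        ...  # existing output-bias adjustment on self.atomic_model
        if bias_adjust_mode == "set-by-statistic":
            # New in this PR: also compute fparam/aparam input statistics,
            # reusing the already-sampled merged data.
            self.atomic_model.compute_fitting_stat(merged)
```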

source/tests/pd/test_training.py (2)

92-100: LGTM! Appropriate test assertion for fparam/aparam keys.

The broadened exclusion condition correctly validates that fparam and aparam statistics are preserved during random finetuning, aligning with the PR's objective to compute fitting statistics properly.


197-197: LGTM! Configuration for data statistics batching.

Adding data_stat_nbatch = 100 appropriately exercises the data statistics batching behavior that's central to this PR's fitting statistics computation.

deepmd/pt/model/atomic_model/base_atomic_model.py (1)

496-512: LGTM! Appropriate placeholder for PT base atomic model.

The no-op implementation is correct for the base class, allowing derived classes to provide concrete implementations. The documentation correctly references torch.Tensor for the PyTorch path.

deepmd/pt/model/model/make_model.py (1)

235-236: LGTM! Correct invocation of compute_fitting_stat in PT path.

The implementation mirrors the PD path and correctly invokes compute_fitting_stat when bias_adjust_mode is "set-by-statistic", addressing the PR's objective for the PyTorch path.

source/tests/pt/test_training.py (2)

95-103: LGTM! Test assertions align with PD path.

The broadened exclusion condition for fparam/aparam keys correctly validates the new fitting statistics computation during random finetuning in the PyTorch path.


263-263: LGTM! Configuration mirrors PD test setup.

Setting data_stat_nbatch = 100 appropriately exercises data statistics batching in the PyTorch path, consistent with the PD tests.

deepmd/pt/model/atomic_model/dp_atomic_model.py (3)

8-8: LGTM! Union import for type hints.

The Union import is correctly added to support the type hints for the new compute_fitting_stat method signature.


332-332: LGTM! Refactored to use compute_fitting_stat.

Good refactoring that centralizes fitting statistics computation through the new compute_fitting_stat method, improving code organization and maintainability.


336-354: LGTM! Proper implementation of compute_fitting_stat.

The method correctly delegates to fitting_net.compute_input_stats with the data_stat_protect parameter, providing a clean interface for computing fitting statistics from packed data.
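Based on the delegation described here and in the walkthrough, a minimal sketch (hypothetical skeleton; docstring and the rest of the class omitted):

```python
from typing import Callable, Union


class DPAtomicModelSketch:
    """Hypothetical skeleton; only the new public method is shown."""

    def compute_fitting_stat(
        self, sample_merged: Union[Callable[[], list[dict]], list[dict]]
    ) -> None:
        # Delegate to the fitting net, forwarding the configured
        # data_stat_protect value as the protection threshold.
        self.fitting_net.compute_input_stats(
            sample_merged, protection=self.data_stat_protect
        )
```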

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 59c50c8 and 646977c.

📒 Files selected for processing (2)
  • deepmd/pd/model/atomic_model/base_atomic_model.py (1 hunks)
  • deepmd/pt/model/atomic_model/dp_atomic_model.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

📄 CodeRabbit inference engine (AGENTS.md)

Always run ruff check . and ruff format . before committing changes to Python code

Files:

  • deepmd/pd/model/atomic_model/base_atomic_model.py
  • deepmd/pt/model/atomic_model/dp_atomic_model.py
🧬 Code graph analysis (2)
deepmd/pd/model/atomic_model/base_atomic_model.py (2)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
  • compute_fitting_stat (336-354)
deepmd/pd/model/atomic_model/dp_atomic_model.py (1)
  • compute_fitting_stat (406-424)
deepmd/pt/model/atomic_model/dp_atomic_model.py (2)
deepmd/pt/model/atomic_model/base_atomic_model.py (1)
  • compute_fitting_stat (496-512)
deepmd/pt/model/task/fitting.py (1)
  • compute_input_stats (78-157)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (28)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (rocm, rocm)
🔇 Additional comments (2)
deepmd/pt/model/atomic_model/dp_atomic_model.py (2)

8-8: LGTM: Import addition supports new method signature.

The Union import is necessary for the Union[Callable[[], list[dict]], list[dict]] type hint in the new compute_fitting_stat method.


332-332: Good refactoring to use the public API.

Calling self.compute_fitting_stat(wrapped_sampler) instead of directly accessing self.fitting_net.compute_input_stats establishes a consistent public interface and enables proper encapsulation.

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
deepmd/pt/model/atomic_model/base_atomic_model.py (1)

496-512: LGTM! Clean API extension for computing fitting statistics.

The method provides a well-documented hook for derived classes to compute fitting statistics during finetuning. The no-op default (pass) is appropriate since not all atomic model types require fitting statistics computation.

Optional: Consider adding a clarifying comment

To make the intent clearer, you could add a brief comment before the pass statement:

         the lazy function helps by only sampling once.
         """
+        # No-op in base class; derived classes override if fitting statistics are needed.
         pass

Reminder: Run code quality checks

As per coding guidelines, ensure you run the following before committing:

ruff check .
ruff format .
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 646977c and 52d29ef.

📒 Files selected for processing (4)
  • deepmd/pd/model/atomic_model/base_atomic_model.py (1 hunks)
  • deepmd/pd/model/atomic_model/dp_atomic_model.py (1 hunks)
  • deepmd/pt/model/atomic_model/base_atomic_model.py (1 hunks)
  • deepmd/pt/model/atomic_model/dp_atomic_model.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • deepmd/pd/model/atomic_model/base_atomic_model.py
  • deepmd/pt/model/atomic_model/dp_atomic_model.py
  • deepmd/pd/model/atomic_model/dp_atomic_model.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

📄 CodeRabbit inference engine (AGENTS.md)

Always run ruff check . and ruff format . before committing changes to Python code

Files:

  • deepmd/pt/model/atomic_model/base_atomic_model.py
🧬 Code graph analysis (1)
deepmd/pt/model/atomic_model/base_atomic_model.py (2)
deepmd/pd/model/atomic_model/dp_atomic_model.py (1)
  • compute_fitting_stat (406-424)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
  • compute_fitting_stat (336-354)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (29)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Analyze (python)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)

@wanghan-iapcm (Collaborator) commented

Limitation: the input stat is not implemented in the Python backend.

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
source/tests/common/dpmodel/test_fitting_stat.py (1)

14-66: Consider renaming helper functions to remove _pt suffix.

The helper functions _make_fake_data_pt, _brute_fparam_pt, and _brute_aparam_pt use a _pt suffix that typically denotes PyTorch-related code. However, these functions are in the dpmodel test directory and use NumPy arrays, not PyTorch tensors.

To avoid confusion, consider renaming them to remove the _pt suffix:

  • _make_fake_data_pt → _make_fake_data
  • _brute_fparam_pt → _brute_fparam
  • _brute_aparam_pt → _brute_aparam

Apply this diff to rename the functions:

-def _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds):
+def _make_fake_data(sys_natoms, sys_nframes, avgs, stds):
     merged_output_stat = []
     nsys = len(sys_natoms)
     ndof = len(avgs)
     for ii in range(nsys):
         sys_dict = {}
         tmp_data_f = []
         tmp_data_a = []
         for jj in range(ndof):
             rng = np.random.default_rng(2025 * ii + 220 * jj)
             tmp_data_f.append(
                 rng.normal(loc=avgs[jj], scale=stds[jj], size=(sys_nframes[ii], 1))
             )
             rng = np.random.default_rng(220 * ii + 1636 * jj)
             tmp_data_a.append(
                 rng.normal(
                     loc=avgs[jj], scale=stds[jj], size=(sys_nframes[ii], sys_natoms[ii])
                 )
             )
         tmp_data_f = np.transpose(tmp_data_f, (1, 2, 0))
         tmp_data_a = np.transpose(tmp_data_a, (1, 2, 0))
         sys_dict["fparam"] = tmp_data_f
         sys_dict["aparam"] = tmp_data_a
         merged_output_stat.append(sys_dict)
     return merged_output_stat


-def _brute_fparam_pt(data, ndim):
+def _brute_fparam(data, ndim):
     adata = [ii["fparam"] for ii in data]
     all_data = []
     for ii in adata:
         tmp = np.reshape(ii, [-1, ndim])
         if len(all_data) == 0:
             all_data = np.array(tmp)
         else:
             all_data = np.concatenate((all_data, tmp), axis=0)
     avg = np.average(all_data, axis=0)
     std = np.std(all_data, axis=0)
     return avg, std


-def _brute_aparam_pt(data, ndim):
+def _brute_aparam(data, ndim):
     adata = [ii["aparam"] for ii in data]
     all_data = []
     for ii in adata:
         tmp = np.reshape(ii, [-1, ndim])
         if len(all_data) == 0:
             all_data = np.array(tmp)
         else:
             all_data = np.concatenate((all_data, tmp), axis=0)
     avg = np.average(all_data, axis=0)
     std = np.std(all_data, axis=0)
     return avg, std

Then update the test to use the renamed functions:

     def test(self) -> None:
         descrpt = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16)
         fitting = EnergyFittingNet(
             descrpt.get_ntypes(),
             descrpt.get_dim_out(),
             neuron=[240, 240, 240],
             resnet_dt=True,
             numb_fparam=3,
             numb_aparam=3,
         )
         avgs = [0, 10, 100]
         stds = [2, 0.4, 0.00001]
         sys_natoms = [10, 100]
         sys_nframes = [5, 2]
-        all_data = _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds)
-        frefa, frefs = _brute_fparam_pt(all_data, len(avgs))
-        arefa, arefs = _brute_aparam_pt(all_data, len(avgs))
+        all_data = _make_fake_data(sys_natoms, sys_nframes, avgs, stds)
+        frefa, frefs = _brute_fparam(all_data, len(avgs))
+        arefa, arefs = _brute_aparam(all_data, len(avgs))
         fitting.compute_input_stats(all_data, protection=1e-2)
         frefs_inv = 1.0 / frefs
         arefs_inv = 1.0 / arefs
         frefs_inv[frefs_inv > 100] = 100
         arefs_inv[arefs_inv > 100] = 100
         np.testing.assert_almost_equal(frefa, fitting.fparam_avg)
         np.testing.assert_almost_equal(
             frefs_inv, fitting.fparam_inv_std
         )
         np.testing.assert_almost_equal(arefa, fitting.aparam_avg)
         np.testing.assert_almost_equal(
             arefs_inv, fitting.aparam_inv_std
         )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3af617c and c6af0e0.

📒 Files selected for processing (3)
  • deepmd/dpmodel/fitting/general_fitting.py (2 hunks)
  • deepmd/pd/model/atomic_model/base_atomic_model.py (1 hunks)
  • source/tests/common/dpmodel/test_fitting_stat.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
source/tests/common/dpmodel/test_fitting_stat.py (3)
deepmd/pd/model/atomic_model/base_atomic_model.py (1)
  • get_ntypes (628-629)
source/tests/tf/common.py (1)
  • numb_aparam (909-910)
deepmd/dpmodel/fitting/general_fitting.py (1)
  • compute_input_stats (225-288)
deepmd/pd/model/atomic_model/base_atomic_model.py (1)
deepmd/pd/model/atomic_model/dp_atomic_model.py (1)
  • compute_fitting_input_stat (404-422)
deepmd/dpmodel/fitting/general_fitting.py (1)
deepmd/pt/model/task/fitting.py (1)
  • compute_input_stats (78-157)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (29)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
🔇 Additional comments (4)
deepmd/pd/model/atomic_model/base_atomic_model.py (1)

518-534: LGTM! Placeholder method correctly defines the public API.

The method signature and docstring are well-structured. The placeholder implementation follows the correct pattern for a base class that will be overridden by derived classes (e.g., DPAtomicModel).

The docstring correctly references paddle.Tensor for the PaddlePaddle framework and clearly describes the purpose of computing input statistics from packed data.

deepmd/dpmodel/fitting/general_fitting.py (2)

7-7: LGTM! Import addition supports the new method signature.

The Callable import is necessary for the type hint in compute_input_stats method.


225-289: Excellent implementation with numerically stable variance calculation.

The method correctly handles both frame and atomic parameters:

  1. Early return optimization: Skips computation when no parameters are present (lines 245-247).
  2. Lazy evaluation support: Properly handles both immediate and deferred data sources (lines 248-251).
  3. Protection threshold: Prevents division by zero by clamping small standard deviations to the protection value before computing inverse std (lines 258-262, 281-285).
  4. Numerically stable aparam computation: Uses parallel aggregation with sum and sum-of-squares for variance calculation (lines 267-280), which is more stable than the direct approach used for fparam.

The implementation is consistent with the PyTorch counterpart in deepmd/pt/model/task/fitting.py.
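A hedged sketch of the sum and sum-of-squares aggregation named in point 4 (a standalone, simplified illustration assuming NumPy arrays shaped (nframes, natoms, ndof); the actual method differs in details):

```python
import numpy as np


def aparam_stats_sketch(systems: list[np.ndarray], protection: float = 1e-2):
    """Aggregate mean/inverse-std via running sum and sum-of-squares."""
    ndof = systems[0].shape[-1]
    count, ssum, ssum2 = 0, np.zeros(ndof), np.zeros(ndof)
    for arr in systems:
        flat = arr.reshape(-1, ndof)
        count += flat.shape[0]
        ssum += flat.sum(axis=0)
        ssum2 += np.square(flat).sum(axis=0)
    avg = ssum / count
    # var = E[x^2] - E[x]^2; clip tiny negatives from floating-point error.
    var = np.maximum(ssum2 / count - np.square(avg), 0.0)
    std = np.sqrt(var)
    # Protection: clamp std before inverting so inv_std <= 1/protection.
    inv_std = 1.0 / np.maximum(std, protection)
    return avg, inv_std
```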

source/tests/common/dpmodel/test_fitting_stat.py (1)

69-99: Well-designed test that validates protection threshold behavior.

The test effectively validates the compute_input_stats implementation:

  1. Comprehensive test data: Uses multiple systems with varying numbers of atoms and frames.
  2. Protection threshold validation: The synthetic data includes stds=[2, 0.4, 0.00001] where the third value (0.00001 << 0.01) specifically tests the protection threshold mechanism.
  3. Correct reference calculation: Lines 88-91 properly compute reference inverse standard deviations and clip them at 100, which corresponds to the protection threshold (1.0 / 0.01 = 100).
  4. Proper assertions: Validates both mean and inverse std for both fparam and aparam.

The test ensures that when standard deviation is below the protection threshold, the inverse std is correctly capped to prevent numerical instability.
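A worked instance of the capping, using the test's third parameter (std = 1e-5 with protection = 1e-2; values taken from the test described above):

```python
import numpy as np

std, protection = 1e-5, 1e-2
# Clamping std to >= protection bounds the stored inverse std at 1/0.01 = 100,
# exactly the clip applied to the reference values in the test.
inv_std = 1.0 / max(std, protection)
assert inv_std == 100.0
assert np.isclose(1.0 / std, 1e5)  # the unprotected value would be 100000
```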

@Chengqian-Zhang (Collaborator, Author) commented

I added the compute_input_stats function in dpmodel/fitting/general_fitting.py.
However, I found that there are no descriptor or output-bias statistics at the atomic_model and model levels in the Python backend, so I also did not implement the fitting stat at those levels.

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
source/tests/common/dpmodel/test_fitting_stat.py (2)

41-66: Reduce duplication between _brute_fparam_pt and _brute_aparam_pt

The two brute-force helpers are identical except for the dictionary key they read, which is a good sign but also duplication. You can simplify and make future changes safer by factoring them into a single generic helper.

For example:

-def _brute_fparam_pt(data, ndim):
-    adata = [ii["fparam"] for ii in data]
-    all_data = []
-    for ii in adata:
-        tmp = np.reshape(ii, [-1, ndim])
-        if len(all_data) == 0:
-            all_data = np.array(tmp)
-        else:
-            all_data = np.concatenate((all_data, tmp), axis=0)
-    avg = np.average(all_data, axis=0)
-    std = np.std(all_data, axis=0)
-    return avg, std
-
-
-def _brute_aparam_pt(data, ndim):
-    adata = [ii["aparam"] for ii in data]
-    all_data = []
-    for ii in adata:
-        tmp = np.reshape(ii, [-1, ndim])
-        if len(all_data) == 0:
-            all_data = np.array(tmp)
-        else:
-            all_data = np.concatenate((all_data, tmp), axis=0)
-    avg = np.average(all_data, axis=0)
-    std = np.std(all_data, axis=0)
-    return avg, std
+def _brute_param_pt(data, key, ndim):
+    chunks = [np.reshape(d[key], [-1, ndim]) for d in data]
+    all_data = np.concatenate(chunks, axis=0)
+    avg = np.average(all_data, axis=0)
+    std = np.std(all_data, axis=0)
+    return avg, std
+
+
+def _brute_fparam_pt(data, ndim):
+    return _brute_param_pt(data, "fparam", ndim)
+
+
+def _brute_aparam_pt(data, ndim):
+    return _brute_param_pt(data, "aparam", ndim)

This keeps the intent clear and avoids having to update two almost-identical functions in the future.


69-95: Test logic correctly mirrors protection/clipping, but consider more robust numeric tolerance

The main test correctly mirrors compute_input_stats’ behavior:

  • You use population standard deviation (np.std(..., ddof=0) implicitly) like the implementation.
  • The protection behavior is matched by computing 1.0 / std and clipping inv-std at 100 when protection=1e-2, which is equivalent to clamping std to >= protection before inversion.
  • You validate both fparam and aparam averages and inverse stddevs across multiple systems and shapes, which gives good coverage of the new path.

One potential improvement is numeric robustness: if EnergyFittingNet stores stats in a lower-precision dtype (e.g., float32), np.testing.assert_almost_equal with the default decimal=7 can be a bit tight and lead to flaky failures. You could either relax the tolerance slightly or use explicit rtol/atol via assert_allclose:

-        np.testing.assert_almost_equal(frefa, fitting.fparam_avg)
-        np.testing.assert_almost_equal(frefs_inv, fitting.fparam_inv_std)
-        np.testing.assert_almost_equal(arefa, fitting.aparam_avg)
-        np.testing.assert_almost_equal(arefs_inv, fitting.aparam_inv_std)
+        np.testing.assert_allclose(frefa, fitting.fparam_avg, rtol=1e-5, atol=1e-7)
+        np.testing.assert_allclose(frefs_inv, fitting.fparam_inv_std, rtol=1e-5, atol=1e-7)
+        np.testing.assert_allclose(arefa, fitting.aparam_avg, rtol=1e-5, atol=1e-7)
+        np.testing.assert_allclose(arefs_inv, fitting.aparam_inv_std, rtol=1e-5, atol=1e-7)

This keeps the assertion strict enough to catch real regressions while making the test less sensitive to dtype or minor numerical differences.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c6af0e0 and 14e64e2.

📒 Files selected for processing (1)
  • source/tests/common/dpmodel/test_fitting_stat.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
source/tests/common/dpmodel/test_fitting_stat.py (1)
deepmd/dpmodel/fitting/general_fitting.py (1)
  • compute_input_stats (225-288)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (28)
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
🔇 Additional comments (1)
source/tests/common/dpmodel/test_fitting_stat.py (1)

14-38: Data generator shapes line up correctly with compute_input_stats

The synthetic data generator is consistent with GeneralFitting.compute_input_stats expectations: fparam and aparam both have the fitting dimension as the last axis, and your upstream reshape to [-1, ndim] will work for both. Using per-parameter default_rng seeds makes the test deterministic and nicely exercises multiple systems/frames.

No changes needed here from a correctness standpoint.
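A small shape check mirroring the layout described above (hypothetical sizes; ndof is the fitting dimension):

```python
import numpy as np

# Both params keep the fitting dimension (ndof) as the last axis, so one
# reshape to [-1, ndof] flattens frames (and atoms) for the statistics.
nframes, natoms, ndof = 5, 10, 3
fparam = np.zeros((nframes, 1, ndof))
aparam = np.zeros((nframes, natoms, ndof))
assert np.reshape(fparam, [-1, ndof]).shape == (nframes, ndof)
assert np.reshape(aparam, [-1, ndof]).shape == (nframes * natoms, ndof)
```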
